{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('../')\n",
    "\n",
    "import collections\n",
    "import os\n",
    "import random\n",
    "from pathlib import Path\n",
    "import logging\n",
    "import shutil\n",
    "import time\n",
    "from packaging import version\n",
    "from collections import defaultdict\n",
    "\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "import gzip\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.nn.parallel import DistributedDataParallel as DDP\n",
    "import torch.distributed as dist\n",
    "import torch.backends.cudnn as cudnn\n",
    "\n",
    "from src.param import parse_args\n",
    "from src.utils import LossMeter\n",
    "from src.dist_utils import reduce_dict\n",
    "from transformers import T5Tokenizer, T5TokenizerFast\n",
    "from src.tokenization import P5Tokenizer, P5TokenizerFast\n",
    "from src.pretrain_model import P5Pretraining\n",
    "\n",
    "_use_native_amp = False\n",
    "_use_apex = False\n",
    "\n",
    "# Check if Pytorch version >= 1.6 to switch between Native AMP and Apex\n",
    "if version.parse(torch.__version__) < version.parse(\"1.6\"):\n",
    "    from transormers.file_utils import is_apex_available\n",
    "    if is_apex_available():\n",
    "        from apex import amp\n",
    "    _use_apex = True\n",
    "else:\n",
    "    _use_native_amp = True\n",
    "    from torch.cuda.amp import autocast\n",
    "\n",
    "from src.trainer_base import TrainerBase\n",
    "\n",
    "import pickle\n",
    "\n",
    "def load_pickle(filename):\n",
    "    with open(filename, \"rb\") as f:\n",
    "        return pickle.load(f)\n",
    "\n",
    "\n",
    "def save_pickle(data, filename):\n",
    "    with open(filename, \"wb\") as f:\n",
    "        pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)\n",
    "        \n",
    "import json\n",
    "\n",
    "def load_json(file_path):\n",
    "    with open(file_path, \"r\") as f:\n",
    "        return json.load(f)\n",
    "    \n",
    "def ReadLineFromFile(path):\n",
    "    lines = []\n",
    "    with open(path,'r') as fd:\n",
    "        for line in fd:\n",
    "            lines.append(line.rstrip('\\n'))\n",
    "    return lines\n",
    "\n",
    "def parse(path):\n",
    "    g = gzip.open(path, 'r')\n",
    "    for l in g:\n",
    "        yield eval(l)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['rating_loss', 'sequential_loss', 'explanation_loss', 'review_loss', 'traditional_loss']\n",
      "Process Launching at GPU 1\n",
      "{'distributed': False, 'multiGPU': True, 'fp16': True, 'train': 'yelp', 'valid': 'yelp', 'test': 'yelp', 'batch_size': 16, 'optim': 'adamw', 'warmup_ratio': 0.02, 'lr': 0.001, 'num_workers': 4, 'clip_grad_norm': 1.0, 'losses': 'rating,sequential,explanation,review,traditional', 'backbone': 't5-small', 'output': 'snap/yelp-small', 'epoch': 10, 'local_rank': 0, 'comment': '', 'train_topk': -1, 'valid_topk': -1, 'dropout': 0.1, 'tokenizer': 'p5', 'max_text_length': 512, 'do_lower_case': False, 'word_mask_rate': 0.15, 'gen_max_length': 64, 'weight_decay': 0.01, 'adam_eps': 1e-06, 'gradient_accumulation_steps': 1, 'seed': 2022, 'whole_word_embed': True, 'world_size': 4, 'LOSSES_NAME': ['rating_loss', 'sequential_loss', 'explanation_loss', 'review_loss', 'traditional_loss', 'total_loss'], 'gpu': 1, 'rank': 1}\n"
     ]
    }
   ],
   "source": [
    "class DotDict(dict):\n",
    "    def __init__(self, **kwds):\n",
    "        self.update(kwds)\n",
    "        self.__dict__ = self\n",
    "        \n",
    "args = DotDict()\n",
    "\n",
    "args.distributed = False\n",
    "args.multiGPU = True\n",
    "args.fp16 = True\n",
    "args.train = \"yelp\"\n",
    "args.valid = \"yelp\"\n",
    "args.test = \"yelp\"\n",
    "args.batch_size = 16\n",
    "args.optim = 'adamw' \n",
    "args.warmup_ratio = 0.02\n",
    "args.lr = 1e-3\n",
    "args.num_workers = 4\n",
    "args.clip_grad_norm = 1.0\n",
    "args.losses = 'rating,sequential,explanation,review,traditional'\n",
    "args.backbone = 't5-small' # small or base\n",
    "args.output = 'snap/yelp-small'\n",
    "args.epoch = 10\n",
    "args.local_rank = 0\n",
    "\n",
    "args.comment = ''\n",
    "args.train_topk = -1\n",
    "args.valid_topk = -1\n",
    "args.dropout = 0.1\n",
    "\n",
    "args.tokenizer = 'p5'\n",
    "args.max_text_length = 512\n",
    "args.do_lower_case = False\n",
    "args.word_mask_rate = 0.15\n",
    "args.gen_max_length = 64\n",
    "\n",
    "args.weight_decay = 0.01\n",
    "args.adam_eps = 1e-6\n",
    "args.gradient_accumulation_steps = 1\n",
    "\n",
    "'''\n",
    "Set seeds\n",
    "'''\n",
    "args.seed = 2022\n",
    "torch.manual_seed(args.seed)\n",
    "random.seed(args.seed)\n",
    "np.random.seed(args.seed)\n",
    "\n",
    "'''\n",
    "Whole word embedding\n",
    "'''\n",
    "args.whole_word_embed = True\n",
    "\n",
    "cudnn.benchmark = True\n",
    "ngpus_per_node = torch.cuda.device_count()\n",
    "args.world_size = ngpus_per_node\n",
    "\n",
    "LOSSES_NAME = [f'{name}_loss' for name in args.losses.split(',')]\n",
    "if args.local_rank in [0, -1]:\n",
    "    print(LOSSES_NAME)\n",
    "LOSSES_NAME.append('total_loss') # total loss\n",
    "\n",
    "args.LOSSES_NAME = LOSSES_NAME\n",
    "\n",
    "gpu = 1 #args.local_rank\n",
    "args.gpu = gpu\n",
    "args.rank = gpu\n",
    "print(f'Process Launching at GPU {gpu}')\n",
    "\n",
    "torch.cuda.set_device('cuda:{}'.format(gpu))\n",
    "\n",
    "comments = []\n",
    "dsets = []\n",
    "if 'yelp' in args.train:\n",
    "    dsets.append('yelp')\n",
    "comments.append(''.join(dsets))\n",
    "if args.backbone:\n",
    "    comments.append(args.backbone)\n",
    "comments.append(''.join(args.losses.split(',')))\n",
    "if args.comment != '':\n",
    "    comments.append(args.comment)\n",
    "comment = '_'.join(comments)\n",
    "\n",
    "if args.local_rank in [0, -1]:\n",
    "    print(args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_config(args):\n",
    "    from transformers import T5Config, BartConfig\n",
    "\n",
    "    if 't5' in args.backbone:\n",
    "        config_class = T5Config\n",
    "    else:\n",
    "        return None\n",
    "\n",
    "    config = config_class.from_pretrained(args.backbone)\n",
    "    config.dropout_rate = args.dropout\n",
    "    config.dropout = args.dropout\n",
    "    config.attention_dropout = args.dropout\n",
    "    config.activation_dropout = args.dropout\n",
    "    config.losses = args.losses\n",
    "\n",
    "    return config\n",
    "\n",
    "\n",
    "def create_tokenizer(args):\n",
    "    from transformers import T5Tokenizer, T5TokenizerFast\n",
    "    from src.tokenization import P5Tokenizer, P5TokenizerFast\n",
    "\n",
    "    if 'p5' in args.tokenizer:\n",
    "        tokenizer_class = P5Tokenizer\n",
    "\n",
    "    tokenizer_name = args.backbone\n",
    "    \n",
    "    tokenizer = tokenizer_class.from_pretrained(\n",
    "        tokenizer_name,\n",
    "        max_length=args.max_text_length,\n",
    "        do_lower_case=args.do_lower_case,\n",
    "    )\n",
    "\n",
    "    print(tokenizer_class, tokenizer_name)\n",
    "    \n",
    "    return tokenizer\n",
    "\n",
    "\n",
    "def create_model(model_class, config=None):\n",
    "    print(f'Building Model at GPU {args.gpu}')\n",
    "\n",
    "    model_name = args.backbone\n",
    "\n",
    "    model = model_class.from_pretrained(\n",
    "        model_name,\n",
    "        config=config\n",
    "    )\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'src.tokenization.P5Tokenizer'> t5-small\n",
      "Building Model at GPU 1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of P5Pretraining were not initialized from the model checkpoint at t5-small and are newly initialized: ['encoder.whole_word_embeddings.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "config = create_config(args)\n",
    "\n",
    "if args.tokenizer is None:\n",
    "    args.tokenizer = args.backbone\n",
    "    \n",
    "tokenizer = create_tokenizer(args)\n",
    "\n",
    "model_class = P5Pretraining\n",
    "model = create_model(model_class, config)\n",
    "\n",
    "model = model.cuda()\n",
    "\n",
    "if 'p5' in args.tokenizer:\n",
    "    model.resize_token_embeddings(tokenizer.vocab_size)\n",
    "    \n",
    "model.tokenizer = tokenizer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "#### Load Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model loaded from  ../snap/yelp-small.pth\n",
      "<All keys matched successfully>\n"
     ]
    }
   ],
   "source": [
    "args.load = \"../snap/yelp-small.pth\"\n",
    "\n",
    "# Load Checkpoint\n",
    "from src.utils import load_state_dict, LossMeter, set_global_logging_level\n",
    "from pprint import pprint\n",
    "\n",
    "def load_checkpoint(ckpt_path):\n",
    "    state_dict = load_state_dict(ckpt_path, 'cpu')\n",
    "    results = model.load_state_dict(state_dict, strict=False)\n",
    "    print('Model loaded from ', ckpt_path)\n",
    "    pprint(results)\n",
    "\n",
    "ckpt_path = args.load\n",
    "load_checkpoint(ckpt_path)\n",
    "\n",
    "from src.all_yelp_templates import all_tasks as task_templates"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "#### Check Test Split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_splits = load_pickle('../data/yelp/rating_splits_augmented.pkl')\n",
    "test_review_data = data_splits['test']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "31636"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(test_review_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'review_id': 'Vx-DPb4olxt1Zxf9dI8b2A',\n",
       " 'useful': 0,\n",
       " 'funny': 0,\n",
       " 'cool': 0,\n",
       " 'date': '2019-01-02 19:11:49',\n",
       " 'unixReviewTime': '20190102191149',\n",
       " 'reviewerID': 'Donht4mLJ4aO4FQhqHGJtw',\n",
       " 'asin': '5eV8oUGdBXylwB7HeaDFOA',\n",
       " 'overall': 4.0,\n",
       " 'reviewText': 'Top notch cuisine and very friendly service. Everything here tastes very fresh. The rice pilaf is second to none!! Prices are very reasonable for the high quality food they serve!!',\n",
       " 'explanation': 'Everything here tastes very fresh',\n",
       " 'feature': 'taste'}"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_review_data[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "30431\n",
      "20033\n"
     ]
    }
   ],
   "source": [
    "data_maps = load_json(os.path.join('../data', 'yelp', 'datamaps.json'))\n",
    "print(len(data_maps['user2id'])) # number of users\n",
    "print(len(data_maps['item2id'])) # number of items"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Test P5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader, Dataset, Sampler\n",
    "from src.pretrain_data import get_loader\n",
    "from evaluate.utils import rouge_score, bleu_score, unique_sentence_percent, root_mean_square_error, mean_absolute_error, feature_detect, feature_matching_ratio, feature_coverage_ratio, feature_diversity\n",
    "from evaluate.metrics4rec import evaluate_all"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Evaluation - Rating"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Data sources:  ['yelp']\n",
      "compute_datum_info\n",
      "1978\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "1978it [02:11, 14.99it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "RMSE  1.4685\n",
      "MAE  1.0054\n"
     ]
    }
   ],
   "source": [
    "test_task_list = {'rating': ['1-10'] # or '1-6'\n",
    "}\n",
    "test_sample_numbers = {'rating': 1, 'sequential': (1, 1, 1), 'explanation': 1, 'review': 1, 'traditional': (1, 1)}\n",
    "\n",
    "zeroshot_test_loader = get_loader(\n",
    "        args,\n",
    "        test_task_list,\n",
    "        test_sample_numbers,\n",
    "        split=args.test, \n",
    "        mode='test', \n",
    "        batch_size=args.batch_size,\n",
    "        workers=args.num_workers,\n",
    "        distributed=args.distributed\n",
    ")\n",
    "print(len(zeroshot_test_loader))\n",
    "\n",
    "gt_ratings = []\n",
    "pred_ratings = []\n",
    "for i, batch in tqdm(enumerate(zeroshot_test_loader)):\n",
    "    with torch.no_grad():\n",
    "        results = model.generate_step(batch)\n",
    "        gt_ratings.extend(batch['target_text'])\n",
    "        pred_ratings.extend(results)\n",
    "        \n",
    "predicted_rating = [(float(r), float(p)) for (r, p) in zip(gt_ratings, pred_ratings) if p in [str(i/10.0) for i in list(range(10, 50))]]\n",
    "RMSE = root_mean_square_error(predicted_rating, 5.0, 1.0)\n",
    "print('RMSE {:7.4f}'.format(RMSE))\n",
    "MAE = mean_absolute_error(predicted_rating, 5.0, 1.0)\n",
    "print('MAE {:7.4f}'.format(MAE))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Data sources:  ['yelp']\n",
      "compute_datum_info\n",
      "1978\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "1978it [01:48, 18.18it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "RMSE  1.4868\n",
      "MAE  1.0186\n"
     ]
    }
   ],
   "source": [
    "test_task_list = {'rating': ['1-6'] # or '1-10'\n",
    "}\n",
    "test_sample_numbers = {'rating': 1, 'sequential': (1, 1, 1), 'explanation': 1, 'review': 1, 'traditional': (1, 1)}\n",
    "\n",
    "zeroshot_test_loader = get_loader(\n",
    "        args,\n",
    "        test_task_list,\n",
    "        test_sample_numbers,\n",
    "        split=args.test, \n",
    "        mode='test', \n",
    "        batch_size=args.batch_size,\n",
    "        workers=args.num_workers,\n",
    "        distributed=args.distributed\n",
    ")\n",
    "print(len(zeroshot_test_loader))\n",
    "\n",
    "gt_ratings = []\n",
    "pred_ratings = []\n",
    "for i, batch in tqdm(enumerate(zeroshot_test_loader)):\n",
    "    with torch.no_grad():\n",
    "        results = model.generate_step(batch)\n",
    "        gt_ratings.extend(batch['target_text'])\n",
    "        pred_ratings.extend(results)\n",
    "        \n",
    "predicted_rating = [(float(r), float(p)) for (r, p) in zip(gt_ratings, pred_ratings) if p in [str(i/10.0) for i in list(range(10, 50))]]\n",
    "RMSE = root_mean_square_error(predicted_rating, 5.0, 1.0)\n",
    "print('RMSE {:7.4f}'.format(RMSE))\n",
    "MAE = mean_absolute_error(predicted_rating, 5.0, 1.0)\n",
    "print('MAE {:7.4f}'.format(MAE))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "#### Evaluation - Sequential"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Data sources:  ['yelp']\n",
      "compute_datum_info\n",
      "1902\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "0it [00:00, ?it/s]/global/homes/z/zw241/.conda/envs/pt-1.10/lib/python3.9/site-packages/transformers/generation_utils.py:1632: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').\n",
      "  next_indices = next_tokens // vocab_size\n",
      "1902it [26:00,  1.22it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "NDCG@5\tRec@5\tHits@5\tPrec@5\tMAP@5\tMRR@5\n",
      "0.0403\t0.0574\t0.0574\t0.0115\t0.0346\t0.0346\n",
      "\n",
      "NDCG@10\tRec@10\tHits@10\tPrec@10\tMAP@10\tMRR@10\n",
      "0.0445\t0.0703\t0.0703\t0.0070\t0.0364\t0.0364\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "('\\nNDCG@10\\tRec@10\\tHits@10\\tPrec@10\\tMAP@10\\tMRR@10\\n0.0445\\t0.0703\\t0.0703\\t0.0070\\t0.0364\\t0.0364',\n",
       " {'ndcg': 0.04450006114321309,\n",
       "  'map': 0.03635002005578063,\n",
       "  'recall': 0.07032302586178568,\n",
       "  'precision': 0.007032302586178309,\n",
       "  'mrr': 0.03635002005578063,\n",
       "  'hit': 0.07032302586178568})"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_task_list = {'sequential': ['2-13'] # or '2-3'\n",
    "}\n",
    "test_sample_numbers = {'rating': 1, 'sequential': (1, 1, 1), 'explanation': 1, 'review': 1, 'traditional': (1, 1)}\n",
    "\n",
    "zeroshot_test_loader = get_loader(\n",
    "        args,\n",
    "        test_task_list,\n",
    "        test_sample_numbers,\n",
    "        split=args.test, \n",
    "        mode='test', \n",
    "        batch_size=args.batch_size,\n",
    "        workers=args.num_workers,\n",
    "        distributed=args.distributed\n",
    ")\n",
    "print(len(zeroshot_test_loader))\n",
    "\n",
    "all_info = []\n",
    "for i, batch in tqdm(enumerate(zeroshot_test_loader)):\n",
    "    with torch.no_grad():\n",
    "        results = model.generate_step(batch)\n",
    "        beam_outputs = model.generate(\n",
    "                batch['input_ids'].to('cuda'), \n",
    "                max_length=50, \n",
    "                num_beams=20,\n",
    "                no_repeat_ngram_size=0, \n",
    "                num_return_sequences=20,\n",
    "                early_stopping=True\n",
    "        )\n",
    "        generated_sents = model.tokenizer.batch_decode(beam_outputs, skip_special_tokens=True)\n",
    "        for j, item in enumerate(zip(results, batch['target_text'], batch['source_text'])):\n",
    "            new_info = {}\n",
    "            new_info['target_item'] = item[1]\n",
    "            new_info['gen_item_list'] = generated_sents[j*20: (j+1)*20]\n",
    "            all_info.append(new_info)\n",
    "            \n",
    "gt = {}\n",
    "ui_scores = {}\n",
    "for i, info in enumerate(all_info):\n",
    "    gt[i] = [int(info['target_item'])]\n",
    "    pred_dict = {}\n",
    "    for j in range(len(info['gen_item_list'])):\n",
    "        try:\n",
    "            pred_dict[int(info['gen_item_list'][j])] = -(j+1)\n",
    "        except:\n",
    "            pass\n",
    "    ui_scores[i] = pred_dict\n",
    "    \n",
    "evaluate_all(ui_scores, gt, 5)\n",
    "evaluate_all(ui_scores, gt, 10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Data sources:  ['yelp']\n",
      "compute_datum_info\n",
      "1902\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "1902it [25:20,  1.25it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "NDCG@5\tRec@5\tHits@5\tPrec@5\tMAP@5\tMRR@5\n",
      "0.0402\t0.0568\t0.0568\t0.0114\t0.0347\t0.0347\n",
      "\n",
      "NDCG@10\tRec@10\tHits@10\tPrec@10\tMAP@10\tMRR@10\n",
      "0.0447\t0.0707\t0.0707\t0.0071\t0.0365\t0.0365\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "('\\nNDCG@10\\tRec@10\\tHits@10\\tPrec@10\\tMAP@10\\tMRR@10\\n0.0447\\t0.0707\\t0.0707\\t0.0071\\t0.0365\\t0.0365',\n",
       " {'ndcg': 0.044696482943397606,\n",
       "  'map': 0.03653161745567516,\n",
       "  'recall': 0.07065163813216785,\n",
       "  'precision': 0.007065163813216525,\n",
       "  'mrr': 0.03653161745567516,\n",
       "  'hit': 0.07065163813216785})"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_task_list = {'sequential': ['2-3'] # or '2-13'\n",
    "}\n",
    "test_sample_numbers = {'rating': 1, 'sequential': (1, 1, 1), 'explanation': 1, 'review': 1, 'traditional': (1, 1)}\n",
    "\n",
    "zeroshot_test_loader = get_loader(\n",
    "        args,\n",
    "        test_task_list,\n",
    "        test_sample_numbers,\n",
    "        split=args.test, \n",
    "        mode='test', \n",
    "        batch_size=args.batch_size,\n",
    "        workers=args.num_workers,\n",
    "        distributed=args.distributed\n",
    ")\n",
    "print(len(zeroshot_test_loader))\n",
    "\n",
    "all_info = []\n",
    "for i, batch in tqdm(enumerate(zeroshot_test_loader)):\n",
    "    with torch.no_grad():\n",
    "        results = model.generate_step(batch)\n",
    "        beam_outputs = model.generate(\n",
    "                batch['input_ids'].to('cuda'), \n",
    "                max_length=50, \n",
    "                num_beams=20,\n",
    "                no_repeat_ngram_size=0, \n",
    "                num_return_sequences=20,\n",
    "                early_stopping=True\n",
    "        )\n",
    "        generated_sents = model.tokenizer.batch_decode(beam_outputs, skip_special_tokens=True)\n",
    "        for j, item in enumerate(zip(results, batch['target_text'], batch['source_text'])):\n",
    "            new_info = {}\n",
    "            new_info['target_item'] = item[1]\n",
    "            new_info['gen_item_list'] = generated_sents[j*20: (j+1)*20]\n",
    "            all_info.append(new_info)\n",
    "            \n",
    "gt = {}\n",
    "ui_scores = {}\n",
    "for i, info in enumerate(all_info):\n",
    "    gt[i] = [int(info['target_item'])]\n",
    "    pred_dict = {}\n",
    "    for j in range(len(info['gen_item_list'])):\n",
    "        try:\n",
    "            pred_dict[int(info['gen_item_list'][j])] = -(j+1)\n",
    "        except:\n",
    "            pass\n",
    "    ui_scores[i] = pred_dict\n",
    "    \n",
    "evaluate_all(ui_scores, gt, 5)\n",
    "evaluate_all(ui_scores, gt, 10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Evaluation - Explanation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Data sources:  ['yelp']\n",
      "compute_datum_info\n",
      "1529\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "1529it [15:42,  1.62it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "BLEU-1 21.3121\n",
      "BLEU-4  3.0584\n",
      "rouge_1/f_score 27.3011\n",
      "rouge_1/r_score 28.7363\n",
      "rouge_1/p_score 30.5548\n",
      "rouge_2/f_score  6.7639\n",
      "rouge_2/r_score  7.7766\n",
      "rouge_2/p_score  7.1706\n",
      "rouge_l/f_score 19.6315\n",
      "rouge_l/r_score 24.0215\n",
      "rouge_l/p_score 23.4873\n"
     ]
    }
   ],
   "source": [
    "test_task_list = {'explanation': ['3-10'] # or '3-7' or '3-2'\n",
    "}\n",
    "test_sample_numbers = {'rating': 1, 'sequential': (1, 1, 1), 'explanation': 1, 'review': 1, 'traditional': (1, 1)}\n",
    "\n",
    "zeroshot_test_loader = get_loader(\n",
    "        args,\n",
    "        test_task_list,\n",
    "        test_sample_numbers,\n",
    "        split=args.test, \n",
    "        mode='test', \n",
    "        batch_size=args.batch_size,\n",
    "        workers=args.num_workers,\n",
    "        distributed=args.distributed\n",
    ")\n",
    "print(len(zeroshot_test_loader))\n",
    "\n",
    "tokens_predict = []\n",
    "tokens_test = []\n",
    "for i, batch in tqdm(enumerate(zeroshot_test_loader)):\n",
    "    with torch.no_grad():\n",
    "        outputs = model.generate(\n",
    "                batch['input_ids'].to('cuda'), \n",
    "                min_length=10,\n",
    "                num_beams=12,\n",
    "                num_return_sequences=1,\n",
    "                num_beam_groups=3\n",
    "        )\n",
    "        results = model.tokenizer.batch_decode(outputs, skip_special_tokens=True)\n",
    "        tokens_predict.extend(results) \n",
    "        tokens_test.extend(batch['target_text'])\n",
    "        \n",
    "new_tokens_predict = [l.split() for l in tokens_predict]\n",
    "new_tokens_test = [ll.split() for ll in tokens_test]\n",
    "BLEU1 = bleu_score(new_tokens_test, new_tokens_predict, n_gram=1, smooth=False)\n",
    "BLEU4 = bleu_score(new_tokens_test, new_tokens_predict, n_gram=4, smooth=False)\n",
    "ROUGE = rouge_score(tokens_test, tokens_predict)\n",
    "\n",
    "print('BLEU-1 {:7.4f}'.format(BLEU1))\n",
    "print('BLEU-4 {:7.4f}'.format(BLEU4))\n",
    "for (k, v) in ROUGE.items():\n",
    "    print('{} {:7.4f}'.format(k, v))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Data sources:  ['yelp']\n",
      "compute_datum_info\n",
      "1529\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "1529it [13:12,  1.93it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "BLEU-1 21.2426\n",
      "BLEU-4  2.9767\n",
      "rouge_1/f_score 27.1814\n",
      "rouge_1/r_score 28.6122\n",
      "rouge_1/p_score 30.3839\n",
      "rouge_2/f_score  6.6840\n",
      "rouge_2/r_score  7.6819\n",
      "rouge_2/p_score  7.0891\n",
      "rouge_l/f_score 19.6292\n",
      "rouge_l/r_score 23.9850\n",
      "rouge_l/p_score 23.5445\n"
     ]
    }
   ],
   "source": [
    "test_task_list = {'explanation': ['3-7'] # or '3-10' or '3-2'\n",
    "}\n",
    "test_sample_numbers = {'rating': 1, 'sequential': (1, 1, 1), 'explanation': 1, 'review': 1, 'traditional': (1, 1)}\n",
    "\n",
    "zeroshot_test_loader = get_loader(\n",
    "        args,\n",
    "        test_task_list,\n",
    "        test_sample_numbers,\n",
    "        split=args.test, \n",
    "        mode='test', \n",
    "        batch_size=args.batch_size,\n",
    "        workers=args.num_workers,\n",
    "        distributed=args.distributed\n",
    ")\n",
    "print(len(zeroshot_test_loader))\n",
    "\n",
    "tokens_predict = []\n",
    "tokens_test = []\n",
    "for i, batch in tqdm(enumerate(zeroshot_test_loader)):\n",
    "    with torch.no_grad():\n",
    "        outputs = model.generate(\n",
    "                batch['input_ids'].to('cuda'), \n",
    "                min_length=10,\n",
    "                num_beams=12,\n",
    "                num_return_sequences=1,\n",
    "                num_beam_groups=3\n",
    "        )\n",
    "        results = model.tokenizer.batch_decode(outputs, skip_special_tokens=True)\n",
    "        tokens_predict.extend(results) \n",
    "        tokens_test.extend(batch['target_text'])\n",
    "        \n",
    "new_tokens_predict = [l.split() for l in tokens_predict]\n",
    "new_tokens_test = [ll.split() for ll in tokens_test]\n",
    "BLEU1 = bleu_score(new_tokens_test, new_tokens_predict, n_gram=1, smooth=False)\n",
    "BLEU4 = bleu_score(new_tokens_test, new_tokens_predict, n_gram=4, smooth=False)\n",
    "ROUGE = rouge_score(tokens_test, tokens_predict)\n",
    "\n",
    "print('BLEU-1 {:7.4f}'.format(BLEU1))\n",
    "print('BLEU-4 {:7.4f}'.format(BLEU4))\n",
    "for (k, v) in ROUGE.items():\n",
    "    print('{} {:7.4f}'.format(k, v))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Data sources:  ['yelp']\n",
      "compute_datum_info\n",
      "1529\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "1529it [04:15,  5.97it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "BLEU-1 14.0830\n",
      "BLEU-4  1.2742\n",
      "rouge_1/f_score 18.1816\n",
      "rouge_1/r_score 18.2451\n",
      "rouge_1/p_score 21.1176\n",
      "rouge_2/f_score  2.9478\n",
      "rouge_2/r_score  3.3080\n",
      "rouge_2/p_score  3.1611\n",
      "rouge_l/f_score 13.5521\n",
      "rouge_l/r_score 15.8694\n",
      "rouge_l/p_score 16.6549\n"
     ]
    }
   ],
   "source": [
    "test_task_list = {'explanation': ['3-2']\n",
    "}\n",
    "test_sample_numbers = {'rating': 1, 'sequential': (1, 1, 1), 'explanation': 1, 'review': 1, 'traditional': (1, 1)}\n",
    "\n",
    "zeroshot_test_loader = get_loader(\n",
    "        args,\n",
    "        test_task_list,\n",
    "        test_sample_numbers,\n",
    "        split=args.test, \n",
    "        mode='test', \n",
    "        batch_size=args.batch_size,\n",
    "        workers=args.num_workers,\n",
    "        distributed=args.distributed\n",
    ")\n",
    "print(len(zeroshot_test_loader))\n",
    "\n",
    "tokens_predict = []\n",
    "tokens_test = []\n",
    "for i, batch in tqdm(enumerate(zeroshot_test_loader)):\n",
    "    with torch.no_grad():\n",
    "        outputs = model.generate(\n",
    "                batch['input_ids'].to('cuda'), \n",
    "                min_length=8\n",
    "        )\n",
    "        results = model.tokenizer.batch_decode(outputs, skip_special_tokens=True)\n",
    "        tokens_predict.extend(results) \n",
    "        tokens_test.extend(batch['target_text'])\n",
    "        \n",
    "new_tokens_predict = [l.split() for l in tokens_predict]\n",
    "new_tokens_test = [ll.split() for ll in tokens_test]\n",
    "BLEU1 = bleu_score(new_tokens_test, new_tokens_predict, n_gram=1, smooth=False)\n",
    "BLEU4 = bleu_score(new_tokens_test, new_tokens_predict, n_gram=4, smooth=False)\n",
    "ROUGE = rouge_score(tokens_test, tokens_predict)\n",
    "\n",
    "print('BLEU-1 {:7.4f}'.format(BLEU1))\n",
    "print('BLEU-4 {:7.4f}'.format(BLEU4))\n",
    "for (k, v) in ROUGE.items():\n",
    "    print('{} {:7.4f}'.format(k, v))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "#### Evaluation - Review"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Data sources:  ['yelp']\n",
      "compute_datum_info\n",
      "1978\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "51it [00:03, 13.66it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "RMSE  0.6350\n",
      "MAE  0.3150\n"
     ]
    }
   ],
   "source": [
    "test_task_list = {'review': ['4-3'] # or '4-2'\n",
    "}\n",
    "test_sample_numbers = {'rating': 1, 'sequential': (1, 1, 1), 'explanation': 1, 'review': 1, 'traditional': (1, 1)}\n",
    "\n",
    "zeroshot_test_loader = get_loader(\n",
    "        args,\n",
    "        test_task_list,\n",
    "        test_sample_numbers,\n",
    "        split=args.test, \n",
    "        mode='test', \n",
    "        batch_size=args.batch_size,\n",
    "        workers=args.num_workers,\n",
    "        distributed=args.distributed\n",
    ")\n",
    "print(len(zeroshot_test_loader))\n",
    "\n",
    "gt_ratings = []\n",
    "pred_ratings = []\n",
    "for i, batch in tqdm(enumerate(zeroshot_test_loader)):\n",
    "    if i > 50:\n",
    "        break\n",
    "    with torch.no_grad():\n",
    "        results = model.generate_step(batch)\n",
    "        gt_ratings.extend(batch['target_text'])\n",
    "        pred_ratings.extend(results)\n",
    "        \n",
    "predicted_rating = [(float(r), round(float(p))) for (r, p) in zip(gt_ratings, pred_ratings)]\n",
    "RMSE = root_mean_square_error(predicted_rating, 5.0, 1.0)\n",
    "print('RMSE {:7.4f}'.format(RMSE))\n",
    "MAE = mean_absolute_error(predicted_rating, 5.0, 1.0)\n",
    "print('MAE {:7.4f}'.format(MAE))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Data sources:  ['yelp']\n",
      "compute_datum_info\n",
      "1978\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "51it [00:03, 13.32it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "RMSE  0.6301\n",
      "MAE  0.3113\n"
     ]
    }
   ],
   "source": [
    "test_task_list = {'review': ['4-2'] # or '4-3'\n",
    "}\n",
    "test_sample_numbers = {'rating': 1, 'sequential': (1, 1, 1), 'explanation': 1, 'review': 1, 'traditional': (1, 1)}\n",
    "\n",
    "zeroshot_test_loader = get_loader(\n",
    "        args,\n",
    "        test_task_list,\n",
    "        test_sample_numbers,\n",
    "        split=args.test, \n",
    "        mode='test', \n",
    "        batch_size=args.batch_size,\n",
    "        workers=args.num_workers,\n",
    "        distributed=args.distributed\n",
    ")\n",
    "print(len(zeroshot_test_loader))\n",
    "\n",
    "gt_ratings = []\n",
    "pred_ratings = []\n",
    "for i, batch in tqdm(enumerate(zeroshot_test_loader)):\n",
    "    if i > 50:\n",
    "        break\n",
    "    with torch.no_grad():\n",
    "        results = model.generate_step(batch)\n",
    "        gt_ratings.extend(batch['target_text'])\n",
    "        pred_ratings.extend(results)\n",
    "        \n",
    "predicted_rating = [(float(r), round(float(p))) for (r, p) in zip(gt_ratings, pred_ratings)]\n",
    "RMSE = root_mean_square_error(predicted_rating, 5.0, 1.0)\n",
    "print('RMSE {:7.4f}'.format(RMSE))\n",
    "MAE = mean_absolute_error(predicted_rating, 5.0, 1.0)\n",
    "print('MAE {:7.4f}'.format(MAE))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "#### Evaluation - Traditional"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Data sources:  ['yelp']\n",
      "compute_datum_info\n",
      "1902\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "1902it [27:03,  1.17it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "NDCG@1\tRec@1\tHits@1\tPrec@1\tMAP@1\tMRR@1\n",
      "0.0648\t0.0648\t0.0648\t0.0648\t0.0648\t0.0648\n",
      "\n",
      "NDCG@5\tRec@5\tHits@5\tPrec@5\tMAP@5\tMRR@5\n",
      "0.1203\t0.1729\t0.1729\t0.0346\t0.1030\t0.1030\n",
      "\n",
      "NDCG@10\tRec@10\tHits@10\tPrec@10\tMAP@10\tMRR@10\n",
      "0.1464\t0.2543\t0.2543\t0.0254\t0.1136\t0.1136\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "('\\nNDCG@10\\tRec@10\\tHits@10\\tPrec@10\\tMAP@10\\tMRR@10\\n0.1464\\t0.2543\\t0.2543\\t0.0254\\t0.1136\\t0.1136',\n",
       " {'ndcg': 0.14637761309667505,\n",
       "  'map': 0.11358604399335892,\n",
       "  'recall': 0.25434589727580426,\n",
       "  'precision': 0.02543458972758396,\n",
       "  'mrr': 0.11358604399335892,\n",
       "  'hit': 0.25434589727580426})"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_task_list = {'traditional': ['5-8']  # or '5-5'\n",
    "}\n",
    "test_sample_numbers = {'rating': 1, 'sequential': (1, 1, 1), 'explanation': 1, 'review': 1, 'traditional': (1, 1)}\n",
    "\n",
    "zeroshot_test_loader = get_loader(\n",
    "        args,\n",
    "        test_task_list,\n",
    "        test_sample_numbers,\n",
    "        split=args.test, \n",
    "        mode='test', \n",
    "        batch_size=args.batch_size,\n",
    "        workers=args.num_workers,\n",
    "        distributed=args.distributed\n",
    ")\n",
    "print(len(zeroshot_test_loader))\n",
    "\n",
    "all_info = []\n",
    "for i, batch in tqdm(enumerate(zeroshot_test_loader)):\n",
    "    with torch.no_grad():\n",
    "        results = model.generate_step(batch)\n",
    "        beam_outputs = model.generate(\n",
    "                batch['input_ids'].to('cuda'), \n",
    "                max_length=50, \n",
    "                num_beams=20,\n",
    "                no_repeat_ngram_size=0, \n",
    "                num_return_sequences=20,\n",
    "                early_stopping=True\n",
    "        )\n",
    "        generated_sents = model.tokenizer.batch_decode(beam_outputs, skip_special_tokens=True)\n",
    "        for j, item in enumerate(zip(results, batch['target_text'], batch['source_text'])):\n",
    "            new_info = {}\n",
    "            new_info['target_item'] = item[1]\n",
    "            new_info['gen_item_list'] = generated_sents[j*20: (j+1)*20]\n",
    "            all_info.append(new_info)\n",
    "            \n",
    "gt = {}\n",
    "ui_scores = {}\n",
    "for i, info in enumerate(all_info):\n",
    "    gt[i] = [int(info['target_item'])]\n",
    "    pred_dict = {}\n",
    "    for j in range(len(info['gen_item_list'])):\n",
    "        try:\n",
    "            pred_dict[int(info['gen_item_list'][j])] = -(j+1)\n",
    "        except:\n",
    "            pass\n",
    "    ui_scores[i] = pred_dict\n",
    "    \n",
    "evaluate_all(ui_scores, gt, 1)\n",
    "evaluate_all(ui_scores, gt, 5)\n",
    "evaluate_all(ui_scores, gt, 10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Data sources:  ['yelp']\n",
      "compute_datum_info\n",
      "1902\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "1902it [26:32,  1.19it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "NDCG@1\tRec@1\tHits@1\tPrec@1\tMAP@1\tMRR@1\n",
      "0.0658\t0.0658\t0.0658\t0.0658\t0.0658\t0.0658\n",
      "\n",
      "NDCG@5\tRec@5\tHits@5\tPrec@5\tMAP@5\tMRR@5\n",
      "0.1223\t0.1765\t0.1765\t0.0353\t0.1046\t0.1046\n",
      "\n",
      "NDCG@10\tRec@10\tHits@10\tPrec@10\tMAP@10\tMRR@10\n",
      "0.1492\t0.2603\t0.2603\t0.0260\t0.1155\t0.1155\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "('\\nNDCG@10\\tRec@10\\tHits@10\\tPrec@10\\tMAP@10\\tMRR@10\\n0.1492\\t0.2603\\t0.2603\\t0.0260\\t0.1155\\t0.1155',\n",
       " {'ndcg': 0.14922877417860922,\n",
       "  'map': 0.11551674540320987,\n",
       "  'recall': 0.26032664059675986,\n",
       "  'precision': 0.026032664059679654,\n",
       "  'mrr': 0.11551674540320987,\n",
       "  'hit': 0.26032664059675986})"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_task_list = {'traditional': ['5-5']  # or '5-8'\n",
    "}\n",
    "test_sample_numbers = {'rating': 1, 'sequential': (1, 1, 1), 'explanation': 1, 'review': 1, 'traditional': (1, 1)}\n",
    "\n",
    "zeroshot_test_loader = get_loader(\n",
    "        args,\n",
    "        test_task_list,\n",
    "        test_sample_numbers,\n",
    "        split=args.test, \n",
    "        mode='test', \n",
    "        batch_size=args.batch_size,\n",
    "        workers=args.num_workers,\n",
    "        distributed=args.distributed\n",
    ")\n",
    "print(len(zeroshot_test_loader))\n",
    "\n",
    "all_info = []\n",
    "for i, batch in tqdm(enumerate(zeroshot_test_loader)):\n",
    "    with torch.no_grad():\n",
    "        results = model.generate_step(batch)\n",
    "        beam_outputs = model.generate(\n",
    "                batch['input_ids'].to('cuda'), \n",
    "                max_length=50, \n",
    "                num_beams=20,\n",
    "                no_repeat_ngram_size=0, \n",
    "                num_return_sequences=20,\n",
    "                early_stopping=True\n",
    "        )\n",
    "        generated_sents = model.tokenizer.batch_decode(beam_outputs, skip_special_tokens=True)\n",
    "        for j, item in enumerate(zip(results, batch['target_text'], batch['source_text'])):\n",
    "            new_info = {}\n",
    "            new_info['target_item'] = item[1]\n",
    "            new_info['gen_item_list'] = generated_sents[j*20: (j+1)*20]\n",
    "            all_info.append(new_info)\n",
    "            \n",
    "gt = {}\n",
    "ui_scores = {}\n",
    "for i, info in enumerate(all_info):\n",
    "    gt[i] = [int(info['target_item'])]\n",
    "    pred_dict = {}\n",
    "    for j in range(len(info['gen_item_list'])):\n",
    "        try:\n",
    "            pred_dict[int(info['gen_item_list'][j])] = -(j+1)\n",
    "        except:\n",
    "            pass\n",
    "    ui_scores[i] = pred_dict\n",
    "    \n",
    "evaluate_all(ui_scores, gt, 1)\n",
    "evaluate_all(ui_scores, gt, 5)\n",
    "evaluate_all(ui_scores, gt, 10)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pt-1.10",
   "language": "python",
   "name": "pt-1.10"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
